feat: allow Tensor.store API to receive .var as value #120

Merged: 1 commit merged into leanprover:main on May 21, 2025

Conversation

aqjune-aws
Collaborator

This patch allows the Tensor.store API (which is connected to nki.language.store) to accept a more generic Core.Value type.

The motivation is tracing interop/test/examples/matmul.py, specifically the nki_matmul_basic_ function.

After applying #111, tracing the Python function raised the following error:

error:
line 44:
  nl.store(result[i_out_p, i_out_f], value=result_sbuf)
  ^-- expecting tensor access

This happens because the value keyword argument held the following expression:

KLR.Trace.Term.expr (KLR.Core.Expr.value (KLR.Core.Value.var "5")) (KLR.Trace.TermType.obj `object)

which could not be converted to Access through the FromNKI typeclass.

The temporary variable "5" comes from the right-hand side of the definition of result_sbuf:

result_sbuf = nl.copy(result_psum, dtype=result.dtype)

To convert the value of "5" into a tensor access, it seems we would need to search the generated trace for the assignment to "5", because RValue hoists calls into temporaries:

def RValue : Term -> Trace Term
...
  | .expr e@(.call ..) ty => do
       let v := (<- genName).toString
       add_stmt (.assign v e)
       return .expr (.value $ .var v) ty

Here add_stmt just appends a Core statement to State.body.
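
To make the hoisting step concrete, here is a toy Python model of it. This is only a sketch: `Tracer`, `Call`, `Var`, and `rvalue` are hypothetical stand-ins for the Lean definitions above, not KLR's API, and starting the name counter at 0 is an assumption.

```python
from dataclasses import dataclass, field
from typing import List, Tuple, Union


@dataclass(frozen=True)
class Var:
    """Reference to a named temporary (stand-in for Core.Value.var)."""
    name: str


@dataclass(frozen=True)
class Call:
    """A call expression such as nl.copy(...) (stand-in for Expr.call)."""
    func: str
    args: Tuple[object, ...]


Expr = Union[Var, Call]


@dataclass
class Tracer:
    """Hypothetical trace state: a fresh-name counter and a statement list."""
    counter: int = 0
    body: List[Tuple[str, Call]] = field(default_factory=list)

    def gen_name(self) -> str:
        name = str(self.counter)
        self.counter += 1
        return name

    def rvalue(self, e: Expr) -> Expr:
        # Calls are hoisted: emit `v = call` and hand back only `Var(v)`.
        if isinstance(e, Call):
            v = self.gen_name()
            self.body.append((v, e))
            return Var(v)
        return e


tracer = Tracer()
result_sbuf = tracer.rvalue(Call("nl.copy", (Var("result_psum"),)))
print(result_sbuf)   # the caller sees only a variable, not the call
print(tracer.body)   # the defining assignment lives in the statement list
```

In this model, whatever consumes `result_sbuf` later (here, the store) receives only the variable. The call that defines it is reachable only through the statement list.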

Scanning State.body to find this assignment to "5" didn't seem like something we want to do inside Tensor.store, so I chose a more conservative approach and simply removed the shape checker.

Any other reasonable option is fine with me.
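
As a rough illustration of the trade-off, the sketch below contrasts a strict store with a relaxed one in plain Python. All names (`Access`, `Var`, `store_strict`, `store_relaxed`) are hypothetical and only model the behavior described in this PR. That the relaxed version performs no shape check at all is an assumption based on the description above.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Access:
    """A tensor access with a statically known shape (hypothetical stand-in)."""
    tensor: str
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class Var:
    """A named temporary whose shape is not known at this point."""
    name: str


def store_strict(dst: Access, value: Union[Access, Var]) -> tuple:
    # Old behavior in this model: only accesses are accepted, shapes must match.
    if not isinstance(value, Access):
        raise TypeError("expecting tensor access")
    if dst.shape != value.shape:
        raise ValueError("shape mismatch")
    return ("store", dst, value)


def store_relaxed(dst: Access, value: Union[Access, Var]) -> tuple:
    # Relaxed behavior in this model: any value is accepted, no shape check.
    return ("store", dst, value)


dst = Access("result", (128, 512))
tmp = Var("5")

try:
    store_strict(dst, tmp)
    strict_error = None
except TypeError as err:
    strict_error = str(err)

stmt = store_relaxed(dst, tmp)
```

The cost of the relaxed form is visible here: a shape mismatch between `dst` and whatever `"5"` refers to is no longer caught at this point.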

@aqjune-aws
Collaborator Author

After this patch and #111, the nki_matmul_basic_ function can finally be fully traced! :)

@govereau
Collaborator

govereau commented May 7, 2025

I think a better approach would be to simplify the variable by looking it up in the environment, as the resulting store (with the variable) doesn't correspond to anything the HW can do. At the KLR Core level, we should not have anything that does not have a corresponding ISA (or BIR) representation.
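
A minimal sketch of what such a lookup could look like, written as a toy Python model rather than KLR code. The names `Var`, `Call`, and `resolve` are hypothetical, the environment is modeled as a plain list of assignments, and the temporary "4" for the matmul result is invented for illustration.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    func: str
    args: Tuple[object, ...]


Expr = Union[Var, Call]


def resolve(body: List[Tuple[str, Expr]], value: Expr) -> Expr:
    """Follow variable references until a non-variable definition is found."""
    defs: Dict[str, Expr] = dict(body)  # later assignments win
    seen = set()
    while isinstance(value, Var) and value.name in defs:
        if value.name in seen:
            raise ValueError(f"cyclic definition of {value.name!r}")
        seen.add(value.name)
        value = defs[value.name]
    return value


body = [
    ("4", Call("nl.matmul", (Var("lhs_tile"), Var("rhs_tile")))),
    ("5", Call("nl.copy", (Var("4"),))),
    ("result_sbuf", Var("5")),
]
resolved = resolve(body, Var("result_sbuf"))  # reaches the nl.copy call
unknown = resolve(body, Var("lhs_tile"))      # no definition: returned as-is
```

With the defining expression in hand, a store could in principle recover the operand it needs instead of emitting a bare variable.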

However, the problem here is more fundamental. The core issue is that matmul is not referentially transparent, and it doesn't make sense to transform it the way we are doing. This statement:

result_psum = nl.matmul(lhs_tile, rhs_tile, transpose_x=True)

is nonsensical because matmul takes the left-hand side as an input (and output) argument. This is why there is so much discussion along the lines of "you must write += with matmul". The matmul functions need to be changed to something like:

nl.matmul(lhs_tile, rhs_tile, dst=result_psum, transpose_x=True, accum_mode=zero)

The above "statement form" would trace with no issues.
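
To illustrate why the statement form is easier to trace, here is a toy Python model comparing the two forms. It is a sketch under stated assumptions: `Trace`, `matmul_expr`, and `matmul_stmt` are hypothetical, the `dst` and `accum_mode` parameters mirror the proposal above rather than any existing NKI signature, and the string `"zero"` is an invented placeholder for the accumulation mode.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Trace:
    """Hypothetical trace state: emitted statements plus a fresh-name counter."""
    body: List[Tuple[str, ...]] = field(default_factory=list)
    counter: int = 0

    def matmul_expr(self, lhs: str, rhs: str) -> str:
        # Expression form: the result has no name yet, so a temporary is made
        # up and the caller only gets that temporary back.
        tmp = str(self.counter)
        self.counter += 1
        self.body.append(("assign", tmp, "matmul", lhs, rhs))
        return tmp

    def matmul_stmt(self, dst: str, lhs: str, rhs: str,
                    accum_mode: str = "zero") -> None:
        # Statement form: the destination is an explicit operand, so the
        # emitted statement is complete and nothing is returned.
        self.body.append(("matmul", dst, lhs, rhs, accum_mode))


expr_trace = Trace()
tmp = expr_trace.matmul_expr("lhs_tile", "rhs_tile")

stmt_trace = Trace()
stmt_trace.matmul_stmt("result_psum", "lhs_tile", "rhs_tile")
```

In the statement form every emitted entry already names its destination, so later operations never have to chase a temporary back to its definition.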

Of course, there are other operators, like tensor_tensor or tensor_scalar, which would also have the issue you spotted here. However, the current thinking is that all ISA functions must be statements and not return any values. If we go this way, then we will remove store from KLR as it will no longer be needed. In fact, the whole add_stmt mechanism could be removed.

Collaborator

@govereau govereau left a comment

Offline discussion: we decided to make this change and fix the matmul/store "weirdness" in a later patch.

@aqjune-aws aqjune-aws changed the title Allow Tensor.store API to receive .var as value feat: allow Tensor.store API to receive .var as value May 21, 2025
@aqjune-aws aqjune-aws merged commit 3910e27 into leanprover:main May 21, 2025
6 checks passed